This is a simple example of Gaussian Processes using Python and the scikit-learn library.
Gaussian Processes (GPs) are a non-parametric, Bayesian approach to modeling unknown functions. A GP defines a probability distribution over functions, fully specified by a mean function and a covariance (kernel) function, so it can be used for regression with built-in uncertainty estimates. GPs are particularly useful when the underlying function is unknown and we want both a mean prediction and a measure of the uncertainty around it.
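To make "a distribution over functions" concrete, here is a minimal sketch that draws a few random functions from a GP prior before any data is seen. The grid, the RBF length scale of 1.0, and the variable names are illustrative assumptions, not part of the example further down.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

# Evaluation grid (illustrative choice)
X_grid = np.linspace(0, 5, 50)[:, np.newaxis]

# An unfitted regressor represents the GP prior: zero mean, RBF covariance
gp_prior = GaussianProcessRegressor(kernel=RBF(length_scale=1.0))

# Draw three functions from the prior; result has shape (n_points, n_samples)
prior_samples = gp_prior.sample_y(X_grid, n_samples=3, random_state=0)

# Before fitting, predict returns the prior mean and prior standard deviation
prior_mean, prior_std = gp_prior.predict(X_grid, return_std=True)

print(prior_samples.shape)
print(prior_mean[:3], prior_std[:3])
```

Each column of prior_samples is one smooth function that is plausible under the prior. Fitting the model to data narrows this family down to the functions that agree with the observations.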
Key concepts of Gaussian Processes:
Gaussian Processes are commonly used in machine learning for tasks such as regression, classification, and Bayesian optimization.
Python Source Code:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
# Define a true underlying function
def true_function(X):
    return np.sin(3 * X) + X**2 - 2 * X
# Generate synthetic data
np.random.seed(42)
X = np.sort(5 * np.random.rand(100, 1), axis=0)
y = true_function(X).ravel() + 0.5 * np.random.normal(size=X.shape[0])
# Define the kernel: an RBF kernel scaled by a constant (signal variance) factor
kernel = C(1.0, (1e-3, 1e3)) * RBF(1.0, (1e-2, 1e2))
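# Create Gaussian Process Regressor with the defined kernel.
# alpha is the noise variance added to the kernel diagonal; 0.5**2 matches
# the noise level of the synthetic data (the default, 1e-10, assumes
# almost noise-free observations).
gp = GaussianProcessRegressor(kernel=kernel, alpha=0.5**2, n_restarts_optimizer=10)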
# Fit the Gaussian Process to the data
gp.fit(X, y)
# Generate test data
X_test = np.linspace(0, 5, 1000)[:, np.newaxis]
# Predict mean and standard deviation of the GP at test points
y_pred, sigma = gp.predict(X_test, return_std=True)
# Plot the true function, observed data, and GP predictions with uncertainty
plt.figure(figsize=(10, 6))
plt.plot(X_test, true_function(X_test), 'k:', label='True Function')
plt.errorbar(X.ravel(), y, 0.5, fmt='r.', markersize=10, label='Observations')
plt.plot(X_test, y_pred, 'b-', label='GP Prediction')
plt.fill_between(X_test.ravel(), y_pred - 1.96 * sigma, y_pred + 1.96 * sigma, alpha=0.2, color='blue', label='95% confidence interval')
plt.title('Gaussian Process Regression')
plt.xlabel('Input')
plt.ylabel('Output')
plt.legend()
plt.show()
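As a follow-up to the script above, the sketch below shows one way to inspect what the model learned: the optimized kernel, the log-marginal likelihood, and how often the true function falls inside the 95% band. It rebuilds the same kind of synthetic data so that it runs on its own. Passing the noise variance through alpha=0.5**2, using only 2 optimizer restarts, and the names noise_std, coverage, and rmse are assumptions made for this sketch.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

def true_function(X):
    return np.sin(3 * X) + X**2 - 2 * X

# Synthetic data built with the same recipe as the main script
rng = np.random.RandomState(42)
noise_std = 0.5
X = np.sort(5 * rng.rand(100, 1), axis=0)
y = true_function(X).ravel() + noise_std * rng.normal(size=X.shape[0])

# Assumption: the noise variance is known and passed via alpha
kernel = C(1.0, (1e-3, 1e3)) * RBF(1.0, (1e-2, 1e2))
gp = GaussianProcessRegressor(kernel=kernel, alpha=noise_std**2,
                              n_restarts_optimizer=2, random_state=0)
gp.fit(X, y)

# Hyperparameters after optimization and the objective value they reach
print("Learned kernel:", gp.kernel_)
print("Log-marginal likelihood:", gp.log_marginal_likelihood_value_)

# Compare the predictions against the noise-free function on a dense grid
X_test = np.linspace(0, 5, 1000)[:, np.newaxis]
y_pred, sigma = gp.predict(X_test, return_std=True)
f_true = true_function(X_test).ravel()

# Fraction of grid points where the true function lies inside the 95% band
inside = np.abs(f_true - y_pred) <= 1.96 * sigma
coverage = inside.mean()
rmse = np.sqrt(np.mean((f_true - y_pred) ** 2))
print("Coverage of the 95% band:", coverage)
print("RMSE against the true function:", rmse)
```

The printed kernel shows the fitted amplitude and length scale, which is a quick sanity check that the optimizer did not stop at one of the bounds given to C and RBF. The coverage figure is only a rough diagnostic here, since neighboring grid points are strongly correlated.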
Explanation: